import os
import sys
sys.path.insert(0, os.getcwd())
import wandb
import os, sys
sys.path.insert(0, os.getcwd())
from analysis.base import Gen_Analysis, np_softmax
import torch
import numpy as np
import seaborn as sns
import torch.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import os
import csv
from scipy import linalg
from tqdm import tqdm
import pickle
import pathlib


#sns.set(style="whitegrid")


class FIDAnalysis(Gen_Analysis):
    """Frechet Inception Distance (FID) analysis.

    Feature statistics (mean and covariance) of the reference data are
    computed once with ``self.model`` and cached as a pickle file under
    ``rootdir``; further data sets are scored against those statistics.
    ``self.model``, ``self.original``, ``self.results_all`` and
    ``self._data_to_loader`` are presumably provided by ``Gen_Analysis``
    (not visible here) — verify against the base class.
    """

    def __init__(self, original, args, rootdir="./out", overwrite=False):
        """Set up the analysis and load (or compute) the reference statistics.

        Args:
            original: reference data, forwarded to ``Gen_Analysis``; when the
                statistics are (re)computed, ``self.original`` is iterated as a
                loader by ``_collect_data``.
            args: parsed command-line namespace; ``args.exp`` selects the
                entry of ``self.results_all`` and ``args.dataset`` names the
                cache file ``<rootdir>/<dataset>.pkl``.
            rootdir: directory holding the cached statistics.
            overwrite: if True, recompute the reference statistics even when a
                cache file already exists, and rewrite the cache.
        """
        super().__init__(original, rootdir=rootdir,
                         args=args)

        self.results_exp = self.results_all[args.exp]
        self.stat_file = os.path.join(rootdir, "{}.pkl".format(args.dataset))

        if os.path.exists(self.stat_file) and not overwrite:
            # NOTE: pickle.load can execute arbitrary code; only use trusted,
            # locally generated cache files here.
            with open(self.stat_file, "rb") as f:
                stats = pickle.load(f)
        else:
            stats = self._collect_data(self.model, self.original)

            pathlib.Path(rootdir).mkdir(parents=True, exist_ok=True)
            with open(self.stat_file, "wb") as f:
                pickle.dump(stats, f)

        self.results_exp["orig"] = stats

    @torch.no_grad()
    def _collect_data(self, model, loader, desc=None):
        """Run ``model`` over ``loader`` and return feature statistics.

        Args:
            model: feature extractor; ``model(x)[0]`` must be a 4-D tensor
                (assumed ``(batch, channels, h, w)`` as with pytorch-fid's
                InceptionV3 — TODO confirm for other models).
            loader: iterable of batches; each batch is a tensor, or a
                list/tuple whose first element is the input tensor.
            desc: optional progress-bar label.

        Returns:
            dict with ``"mu"`` (feature mean) and ``"cov"`` (feature covariance).

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        model.eval()
        features = []
        for x in tqdm(loader, desc="Collecting features for Analysis" if desc is None else desc):
            # Datasets returning (input, label) pairs are batched as sequences.
            if isinstance(x, (list, tuple)):
                x = x[0]

            x = x.cuda(non_blocking=True)

            pred = model(x)[0]

            # If model output is not scalar, apply global spatial average pooling.
            # This happens if you choose a dimensionality not equal 2048.
            # Pooling ops live in torch.nn.functional, not torch.functional.
            if pred.size(2) != 1 or pred.size(3) != 1:
                pred = torch.nn.functional.adaptive_avg_pool2d(pred, output_size=(1, 1))

            # (batch, C, 1, 1) -> (batch, C); one row per sample.
            pred = pred.squeeze(3).squeeze(2).cpu().numpy()
            features.extend(pred)

        if not features:
            raise ValueError(
                "No samples were collected from the loader; "
                "cannot compute FID statistics.")

        mu, sigma = self._compute_statistics(np.stack(features))

        return {"mu": mu, "cov": sigma}

    def _compute_statistics(self, activation):
        """Return the per-feature mean and covariance of ``activation``.

        ``activation`` is a 2-D array with one row per sample, so the
        covariance is computed with ``rowvar=False``.
        """
        mu = np.mean(activation, axis=0)
        sigma = np.cov(activation, rowvar=False)

        return mu, sigma

    # Backward-compatible alias for the original (misspelled) method name.
    _compute_staistics = _compute_statistics

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Numpy implementation of the Frechet Distance.

        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
        Stable version by Dougal J. Sutherland.

        Params:
        -- mu1   : Mean vector of the feature activations of the first set.
        -- sigma1: Covariance matrix of the activations of the first set.
        -- mu2   : Mean vector of the feature activations of the second set.
        -- sigma2: Covariance matrix of the activations of the second set.
        -- eps   : Value added to both covariance diagonals when the matrix
                   square root of their product is not finite.
        Returns:
        --   : The squared Frechet distance d^2 (the FID score).
        Raises:
        --   AssertionError if the mean or covariance shapes differ.
        --   ValueError if the matrix square root has a non-negligible
             imaginary component on its diagonal.
        """

        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1)
                + np.trace(sigma2) - 2 * tr_covmean)

    def _compute_fids(self, stat1, stat2):
        """Return the FID between two ``(mu, cov)`` pairs."""
        mu1, cov1 = stat1
        mu2, cov2 = stat2

        return self.calculate_frechet_distance(mu1, cov1, mu2, cov2)

    def _fid_against_orig(self, key):
        """Return the FID between the reference stats and ``results_exp[key]``."""
        orig = self.results_exp["orig"]
        other = self.results_exp[key]
        return self._compute_fids((orig["mu"], orig["cov"]),
                                  (other["mu"], other["cov"]))

    def get_xys(self, data, desc="generated"):
        """Collect statistics for ``data`` under key ``desc`` and return its FID.

        ``data`` is converted to a loader by ``self._data_to_loader``; the
        resulting statistics are stored in ``self.results_exp[desc]`` and
        compared against the reference statistics.
        """
        dloader = self._data_to_loader(data)
        self.results_exp[desc] = self._collect_data(
            model=self.model, loader=dloader, desc=desc)

        return self._fid_against_orig(desc)

    def plot(self, rootdir=None):
        """Print the FID of every stored entry against the reference stats.

        ``rootdir`` is accepted for interface compatibility and is unused.
        """
        for key in self.results_exp:
            if key == "orig":
                continue

            fid_score = self._fid_against_orig(key)

            print("FID Score {}-{} : {}".format("orig", key, fid_score))

    def _file_name(self, postfix, rootdir=None):
        """Return ``<rootdir>/fid <postfix>``, defaulting to ``self.rootdir``."""
        return os.path.join(self.rootdir if rootdir is None else rootdir, "fid {}".format(postfix))

    def print(self,):
        """No-op: FID scores are reported by ``plot``."""
        return  # nothing to print

    def to_csv(self, rootdir=None):
        """No-op: this analysis writes no CSV output."""
        return  # nothing to write

if __name__ == '__main__':
    # Smoke test: build (or load cached) CIFAR-10 reference statistics and
    # score subsets of the same training data against them.
    from datatool.datatool import get_dl_tr, data_path
    from torchvision import datasets
    from torch.utils.data import DataLoader
    import argparse

    from torchvision import transforms
    from tools.utils import init_distributed_mode

    parser = argparse.ArgumentParser()
    parser.add_argument('--num-workers', type=int, default=0,
                        help='number of data-loading worker processes')
    parser.add_argument('--bsz', type=int, default=32,
                        help='batch_size_per_gpu')
    parser.add_argument('--debug', type=int, default=1,
                        help='debug or not')
    parser.add_argument('--x_sigma', type=float, default=0.001,
                        help='x_sigma')
    parser.add_argument('--exp', type=str, default="test",
                        help='experiment name (key of the stored analysis results)')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        help='dataset name; also names the cached statistics file')

    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
            distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()

    dt_cifar = datasets.CIFAR10(root=data_path["cifar10"] + 'train/', train=True,
                                download=True,
                                transform=transforms.Compose([transforms.ToTensor()]))

    dl_cifar = DataLoader(
        dt_cifar,
        batch_size=args.bsz,
        num_workers=args.num_workers,
    )

    # First construction computes the reference statistics (or loads the cache).
    fid_anl = FIDAnalysis(dl_cifar, rootdir="./out", args=args)
    score = fid_anl.get_xys((dt_cifar.data[:200],), desc="cifar2")  # returns the score directly
    print(score)

    # Second construction reuses the cached reference statistics.
    fid_anl = FIDAnalysis((dt_cifar.data,), rootdir="./out", args=args)
    score = fid_anl.get_xys((dt_cifar.data,), desc="cifar2")  # returns the score directly

    print(score)
    # fid_anl.get_xys((dt_svhn.data[3000:6000],), desc="svhn2")

    fid_anl.plot()
